DeconvGradFilter

计算反卷积(Deconvolution / Transposed Convolution)算子的权重梯度,用于反向传播阶段。 该算子根据输出梯度 dy 与输入特征 x,通过 im2row + GEMM 的方式累加得到卷积核梯度 dw,并支持分组(group)计算。

\[\frac{\partial W}{\partial L} = \sum_{b=0}^{B-1} \text{Im2Row}(dY_b) \cdot X_b\]

其中每个 Group 独立计算,最终在 Group 维度上拼接。

输入:
  • dy_data - 输出特征梯度地址,形状为 [batch, out_h, out_w, out_c]

  • x_data - 输入特征地址,形状为 [batch, in_h, in_w, in_c]

  • param - 参数数组地址,用于描述反卷积计算相关参数与工作空间。
    • param[1] : in_h

    • param[2] : in_w

    • param[3] : in_c

    • param[4] : batch

    • param[5] : out_h

    • param[6] : out_w

    • param[7] : out_c

    • param[8] : kernel_h

    • param[9] : kernel_w

    • param[16] : group

    • param[17] : im2row 工作缓冲区地址

  • core_mask - 核掩码(仅适用于共享存储版本)。

输出:
  • dw_data - 权重梯度输出地址,布局为 [group, out_c/group * k_h * k_w, in_c/group]

支持平台:

FT78NE MT7004

备注

  • FT78NE 仅支持 fp 类型

  • MT7004 支持 hp, fp 类型

  • 输入与输出数据格式为 NHWC

共享存储版本:

void hp_deconvgradfilter_s(half *dy_data, half *x_data, half *dw_data, long long *param, int core_mask)
void fp_deconvgradfilter_s(float *dy_data, float *x_data, float *dw_data, long long *param, int core_mask)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <deconvgradfilter.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *dy_data = (float *)0xA0000000;
 7    float *x_data  = (float *)0xA1000000;
 8    float *dw_data = (float *)0xC0000000;
 9    long long *param = (long long *)0xA2000000;
10    int core_mask = 0xff;
11
12    fp_deconvgradfilter_s(dy_data, x_data, dw_data, param, core_mask);
13    return 0;
14}

私有存储版本:

void hp_deconvgradfilter_p(half *dy_data, half *x_data, half *dw_data, long long *param)
void fp_deconvgradfilter_p(float *dy_data, float *x_data, float *dw_data, long long *param)

C调用示例:

 1//FT78NE示例
 2#include <stdio.h>
 3#include <deconvgradfilter.h>
 4
 5int main(int argc, char* argv[]) {
 6    float *dy_data = (float *)0x10810000;   // L2空间
 7    float *x_data  = (float *)0x10820000;
 8    float *dw_data = (float *)0x10830000;
 9    long long *param = (long long *)0x10840000;
10
11    fp_deconvgradfilter_p(dy_data, x_data, dw_data, param);
12    return 0;
13}